import torch
import time
from logging import getLogger
import numpy as np
from PFSPEnv import PFSPEnv as Env
from PFSPModel import PFSPModel as Model
import numpy as np
from utils import get_result_folder, AverageMeter, TimeEstimator
from torch.utils.data import DataLoader
# Module-level default device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class PFSPTester:
    """Evaluate a pre-trained PFSP model on benchmark instances.

    The constructor builds the environment and the model, restores the model
    weights from a checkpoint and wraps the benchmark array stored in
    ``./Benchmark/tai{job_cnt}x{mc_cnt}.npy`` in a ``DataLoader``.
    :meth:`run` rolls the model out over that data and logs score and gap
    statistics.
    """

    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):
        """Set up device, environment, model and benchmark data.

        Args:
            env_params: Keyword arguments forwarded to ``Env``. The keys
                ``job_cnt``, ``mc_cnt``, ``mode`` and ``pomo_size`` are read
                here as well.
            model_params: Keyword arguments forwarded to ``Model``. The keys
                ``latent_cont_size`` and ``latent_disc_size`` are read here
                as well.
            tester_params: Test configuration. Read here: ``test_episodes``,
                ``test_batch_size``, ``use_cuda``, ``cuda_device_num`` (only
                when ``use_cuda`` is true) and ``model_load`` (a mapping with
                ``path`` and ``epoch``). Read by :meth:`run`:
                ``augmentation_enable``, ``aug_factor`` (only when
                augmentation is enabled) and ``epochs``.
        """
        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()

        self.n_jobs = self.env_params['job_cnt']
        self.n_mc = self.env_params['mc_cnt']
        self.mode = self.env_params['mode']
        self.pomo_size = self.env_params['pomo_size']
        self.latent_cont_dim = self.model_params['latent_cont_size']
        self.latent_disc_dim = self.model_params['latent_disc_size']

        # Test size
        self.test_size = self.tester_params['test_episodes']
        self.test_batch_size = self.tester_params['test_batch_size']

        # cuda
        # NOTE: the default tensor type is switched process-wide, so every
        # tensor created afterwards without an explicit device follows it.
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # Restore
        model_load = self.tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        # Accept both a bare state dict and a checkpoint that wraps the
        # weights under 'model_state_dict'.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']
        self.model.load_state_dict(checkpoint)
        self.logger.info('...Load Pre-trained model...')

        # utility
        self.time_estimator = TimeEstimator()
        data_load = np.load(f'./Benchmark/tai{self.n_jobs}x{self.n_mc}.npy')
        self.test_dataset = torch.Tensor(data_load)
        # shuffle=False keeps the instances in file order, which run() relies
        # on to pair every instance with its upper bound.
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False, generator=torch.Generator(device=device))

    def _sample_latent(self, batch_size):
        """Draw a random latent vector for every (instance, rollout) pair.

        Args:
            batch_size: Number of instances to sample for.

        Returns:
            Float tensor of shape ``(batch_size, pomo_size,
            latent_disc_dim + latent_cont_dim)``: a random one-hot block
            followed by values drawn uniformly from [-1, 1).
        """
        pomo_size = self.env.pomo_size
        # The continuous part is drawn before the discrete indices so that the
        # order of random draws stays fixed for seeded runs.
        latent_c_var = torch.empty(batch_size, pomo_size, self.latent_cont_dim).uniform_(-1, 1)
        one_hot_idx = torch.randint(0, self.latent_disc_dim, (batch_size, pomo_size), dtype=torch.long)
        latent_d_var = torch.zeros((batch_size, pomo_size, self.latent_disc_dim), dtype=torch.float32)
        latent_d_var.scatter_(2, one_hot_idx.unsqueeze(-1), 1.0)
        return torch.cat([latent_d_var, latent_c_var], dim=-1)

    @staticmethod
    def _gap_percent(score, upper_bound):
        """Return the relative excess of ``score`` over ``upper_bound`` in percent."""
        return (score - upper_bound) / upper_bound * 100

    def run(self):
        """Roll the model out over the benchmark and log the results.

        For every batch the best reward over the rollout axis is taken; with
        augmentation enabled the best over the augmented copies is taken as
        well. Rewards are negated to obtain the scores logged under the
        "Makespan" heading. Scores and gaps are averaged over all instances,
        weighted by the actual batch lengths.
        """
        self.time_estimator.reset()
        score_AM = AverageMeter()
        no_aug_AM = AverageMeter()
        gap_AM = AverageMeter()
        no_aug_gap_AM = AverageMeter()
        inference_start_t = time.time()

        # Assumes the file stores one upper bound per benchmark instance, in
        # dataset order; flattened so it can be sliced batch by batch.
        ms_ub = torch.Tensor(np.load(f'./Benchmark/tai{self.n_jobs}x{self.n_mc}_ub.npy')).reshape(-1)

        augment = self.tester_params['augmentation_enable']
        aug_factor = self.tester_params['aug_factor'] if augment else 1
        total_epochs = self.tester_params['epochs']
        pomo_size = self.env.pomo_size
        instance_offset = 0

        self.model.eval()
        for epoch, problems_batched in enumerate(self.test_dataloader, start=1):
            # Use the real batch length: the final batch may be shorter than
            # test_batch_size.
            batch_size = problems_batched.size(0)
            latent_var = self._sample_latent(batch_size)
            if augment:
                # Every augmented copy of an instance reuses the same latents.
                problems_batched = problems_batched.repeat(aug_factor, 1, 1)
                latent_var = latent_var.repeat(aug_factor, 1, 1)

            with torch.no_grad():
                self.env.load_problems_manual(problems_batched)
                reset_state, _, _ = self.env.reset()

                self.model.pre_forward(reset_state, latent_var)

                state, reward, done = self.env.pre_step()
                while not done:
                    selected, _ = self.model(state)
                    # shape: (batch, pomo)
                    state, reward, done = self.env.step(selected)

                aug_reward = reward.reshape(aug_factor, batch_size, pomo_size)
                max_pomo_reward, _ = aug_reward.max(dim=2)
                max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)

                # Upper bounds belonging to the instances of this batch only.
                ub_batch = ms_ub[instance_offset:instance_offset + batch_size].to(aug_reward.device)
                instance_offset += batch_size

                no_aug_score = -max_pomo_reward[0, :].float().mean()
                no_aug_gap = self._gap_percent(-max_pomo_reward[0, :], ub_batch).mean()

                aug_score = -max_aug_pomo_reward.float().mean()
                aug_gap = self._gap_percent(-max_aug_pomo_reward, ub_batch).mean()

                score_AM.update(aug_score.item(), batch_size)
                no_aug_AM.update(no_aug_score.item(), batch_size)
                gap_AM.update(aug_gap.item(), batch_size)
                no_aug_gap_AM.update(no_aug_gap.item(), batch_size)

                elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, total_epochs)
                self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}]".format(epoch, total_epochs, elapsed_time_str, remain_time_str))

        self.logger.info(" *** Test Done *** ")
        self.logger.info(" Inference Time(s): {:.4f}s".format(time.time()-inference_start_t))
        self.logger.info(" *** Makespan *** ")
        self.logger.info(" Aug Test SCORE: {:.4f} ".format(score_AM.avg))
        self.logger.info(" No Aug Test SCORE: {:.4f} ".format(no_aug_AM.avg))
        self.logger.info(" *** Gap *** ")
        self.logger.info(" Aug Test Gap: {:.4f} ".format(gap_AM.avg))
        self.logger.info(" No Aug Test Gap: {:.4f} ".format(no_aug_gap_AM.avg))